{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "omegafold.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/omegafold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **OmegaFold**\n",
        "For more details, see: [GitHub](https://github.com/HeliXonProtein/OmegaFold), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.21.500999v1)\n",
        "\n",
        "#### **Tips and Instructions**\n",
        "- Click the little ▶ play icon to the left of each cell below.\n",
        "- Use \"/\" (or \":\") to specify chain breaks (e.g. sequence=\"AAA/AAA\").\n",
        "- For homo-oligomeric predictions, set copies > 1.\n"
      ],
      "metadata": {
        "id": "qEgffIPOyEgk"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "oVkGfduwiAkw"
      },
      "outputs": [],
      "source": [
        "#@markdown ##Install\n",
        "import os,sys,re\n",
        "from IPython.utils import io\n",
        "# SETUP_DONE lives in the notebook namespace, so re-running this cell in the\n",
        "# same kernel skips the whole setup.\n",
        "if \"SETUP_DONE\" not in dir():\n",
        "  import torch\n",
        "  # Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
        "  device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "  # Final location of the downloaded weights (target of the `mv` below).\n",
        "  ckpt_path = os.path.expanduser(\"~/.cache/omegafold_ckpt/model.pt\")\n",
        "  with io.capture_output() as captured:\n",
        "    # Each step has its own guard: previously everything hung off the\n",
        "    # OmegaFold directory check, so a run that died after the clone never\n",
        "    # installed the packages or fetched the weights on retry.\n",
        "    if not os.path.isdir(\"OmegaFold\"):\n",
        "      %shell git clone --branch beta --quiet https://github.com/sokrypton/OmegaFold.git\n",
        "    # %shell cd OmegaFold; pip -q install -r requirements.txt\n",
        "    %shell pip -q install py3Dmol biopython\n",
        "    if not os.path.isfile(ckpt_path):\n",
        "      %shell apt-get install aria2 -qq > /dev/null\n",
        "      %shell aria2c -q -x 16 https://helixon.s3.amazonaws.com/release1.pt\n",
        "      %shell mkdir -p ~/.cache/omegafold_ckpt\n",
        "      %shell mv release1.pt ~/.cache/omegafold_ckpt/model.pt\n",
        "  # Only reached when every command above succeeded.\n",
        "  SETUP_DONE = True"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown ##Run **OmegaFold**\n",
        "from string import ascii_uppercase, ascii_lowercase\n",
        "import hashlib\n",
        "\n",
        "def get_hash(x):\n",
        "  \"\"\"Return the hexadecimal SHA-1 digest of the string `x`.\"\"\"\n",
        "  digest = hashlib.sha1(x.encode())\n",
        "  return digest.hexdigest()\n",
        "\n",
        "# Chain labels in the order they are handed out: A-Z, then a-z.\n",
        "alphabet_list = [*ascii_uppercase, *ascii_lowercase]\n",
        "\n",
        "jobname = \"test\" #@param {type:\"string\"}\n",
        "# Keep word characters only (max 50) so the name is safe in file names and\n",
        "# in the shell command that launches the prediction.\n",
        "jobname = re.sub(r'\\W+', '', jobname)[:50]\n",
        "\n",
        "sequence = \"PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\" #@param {type:\"string\"}\n",
        "# Normalise the input: \"/\" and \":\" both mark chain breaks, everything that is\n",
        "# not A-Z or \":\" is dropped, repeated/leading/trailing breaks are removed.\n",
        "sequence = re.sub(\"[^A-Z:]\", \"\", sequence.replace(\"/\",\":\").upper())\n",
        "sequence = re.sub(\":+\",\":\",sequence)\n",
        "sequence = re.sub(\"^[:]+\",\"\",sequence)\n",
        "sequence = re.sub(\"[:]+$\",\"\",sequence)\n",
        "# Fail early instead of writing an empty FASTA record.\n",
        "if not sequence:\n",
        "  raise ValueError(\"sequence is empty after removing invalid characters\")\n",
        "\n",
        "copies = 1 #@param {type:\"integer\"}\n",
        "# copies < 1 would turn the sequence into an empty string.\n",
        "if copies < 1:\n",
        "  raise ValueError(\"copies must be >= 1\")\n",
        "sequence = \":\".join([sequence] * copies)\n",
        "#@markdown **Advanced Options**\n",
        "num_cycle = 4 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\"] {type:\"raw\"}\n",
        "offset_rope = False #@param {type:\"boolean\"}\n",
        "\n",
        "# Job ID = sanitised name + first 5 hex characters of the sequence hash.\n",
        "ID = jobname+\"_\"+get_hash(sequence)[:5]\n",
        "# One entry per chain, and the residue count of each chain.\n",
        "seqs = sequence.split(\":\")\n",
        "lengths = [len(s) for s in seqs]\n",
        "\n",
        "def get_subbatch_size(L):\n",
        "  \"\"\"Return the sub-batch size to use for a total length of `L` residues.\"\"\"\n",
        "  # (exclusive upper bound on L, sub-batch size), checked in ascending order\n",
        "  for upper_bound, size in ((500, 500), (1000, 200)):\n",
        "    if L < upper_bound:\n",
        "      return size\n",
        "  return 150\n",
        "# The combined length of all chains decides the sub-batch size.\n",
        "subbatch_size = get_subbatch_size(sum(lengths))\n",
        "\n",
        "# Classify the job: a single chain, identical copies, or differing chains.\n",
        "u_seqs = list(set(seqs))\n",
        "if len(seqs) == 1:\n",
        "  mode = \"mono\"\n",
        "else:\n",
        "  mode = \"homo\" if len(u_seqs) == 1 else \"hetero\"\n",
        "\n",
        "# One FASTA record; the chains stay joined by \":\" on the sequence line.\n",
        "with open(f\"{ID}.fasta\",\"w\") as out:\n",
        "  out.write(f\">{ID}\\n{sequence}\\n\")\n",
        "\n",
        "# Launch the prediction on the FASTA file; the trailing \".\" is presumably the\n",
        "# output directory (the next step reads {ID}.pdb from the working directory).\n",
        "%shell python OmegaFold/main.py --offset_rope={offset_rope} --device={device} --subbatch_size={subbatch_size} --num_cycle={num_cycle} {ID}.fasta .\n",
        "\n",
        "def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1):\n",
        "  \"\"\"Renumber, and optionally re-chain, the ATOM records of a PDB string.\n",
        "\n",
        "  Only lines starting with \"ATOM\" are kept; every other line is dropped.\n",
        "\n",
        "  Args:\n",
        "    pdb_str: contents of a PDB file.\n",
        "    Ls: optional list of chain lengths in residues. When given, residues are\n",
        "      assigned in order of appearance to chains labelled from `alphabet_list`\n",
        "      and numbering restarts at `offset` for each chain. When None the chain\n",
        "      column is left as it is.\n",
        "    renum: write the running residue number if True, otherwise keep the\n",
        "      residue number read from the input.\n",
        "    offset: number given to the first residue (of every chain when `Ls` is set).\n",
        "\n",
        "  Returns:\n",
        "    The rewritten ATOM lines joined with newlines.\n",
        "\n",
        "  Raises:\n",
        "    KeyError: if `Ls` is given and the input holds more residues than the\n",
        "      chain map built from `Ls` covers.\n",
        "  \"\"\"\n",
        "  chain_of = None\n",
        "  if Ls is not None:\n",
        "    # residue index (0-based, counted over the whole input) -> chain letter\n",
        "    chain_of, start = {}, 0\n",
        "    for length, letter in zip(Ls, alphabet_list):\n",
        "      for idx in range(start, start + length):\n",
        "        chain_of[idx] = letter\n",
        "      start += length\n",
        "\n",
        "  res_idx, num, pdb_out = 0, offset, []\n",
        "  prev_resnum = prev_chain = None\n",
        "  # The chain map used to be read unconditionally here, which raised an\n",
        "  # UnboundLocalError for Ls=None (the default); consult it only if it exists.\n",
        "  active_letter = None if chain_of is None else chain_of.get(0)\n",
        "  for line in pdb_str.split(\"\\n\"):\n",
        "    if line[:4] != \"ATOM\":\n",
        "      continue\n",
        "    chain = line[21:22]\n",
        "    # Five columns are parsed here, four (22-25) are rewritten below.\n",
        "    resnum = int(line[22:27])\n",
        "    if prev_resnum is None: prev_resnum = resnum\n",
        "    if prev_chain is None: prev_chain = chain\n",
        "    if resnum != prev_resnum or chain != prev_chain:\n",
        "      # New residue: advance by the input's own step so numbering gaps survive.\n",
        "      num += resnum - prev_resnum\n",
        "      res_idx += 1\n",
        "      prev_resnum, prev_chain = resnum, chain\n",
        "    if chain_of is not None and chain_of[res_idx] != active_letter:\n",
        "      # First residue of the next chain: restart the numbering.\n",
        "      num = offset\n",
        "      active_letter = chain_of[res_idx]\n",
        "    N = num if renum else resnum\n",
        "    if chain_of is None:\n",
        "      pdb_out.append(\"%s%4i%s\" % (line[:22], N, line[26:]))\n",
        "    else:\n",
        "      pdb_out.append(\"%s%s%4i%s\" % (line[:21], chain_of[res_idx], N, line[26:]))\n",
        "  return \"\\n\".join(pdb_out)\n",
        "\n",
        "# Rewrite the predicted structure in place with per-chain labels/numbering.\n",
        "# The input is read inside a `with` block so the handle is closed before the\n",
        "# same path is reopened for writing.\n",
        "with open(f\"{ID}.pdb\",\"r\") as pdb_in:\n",
        "  pdb_str = renum_pdb_str(pdb_in.read(), Ls=lengths)\n",
        "with open(f\"{ID}.pdb\",\"w\") as out:\n",
        "  out.write(pdb_str)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "CFCwEAa2oZEN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown ##Display\n",
        "import py3Dmol\n",
        "\n",
        "\n",
        "# Hex colours used when colouring by chain; entry i goes to the i-th chain.\n",
        "# NOTE(review): the name suggests these mirror PyMOL's palette - unverified.\n",
        "pymol_color_list = [\"#33ff33\",\"#00ffff\",\"#ff33cc\",\"#ffff00\",\"#ff9999\",\"#e5e5e5\",\"#7f7fff\",\"#ff7f00\",\n",
        "                    \"#7fff7f\",\"#199999\",\"#ff007f\",\"#ffdd5e\",\"#8c3f99\",\"#b2b2b2\",\"#007fff\",\"#c4b200\",\n",
        "                    \"#8cb266\",\"#00bfbf\",\"#b27f7f\",\"#fcd1a5\",\"#ff7f7f\",\"#ffbfdd\",\"#7fffff\",\"#ffff7f\",\n",
        "                    \"#00ff7f\",\"#337fcc\",\"#d8337f\",\"#bfff3f\",\"#ff7fff\",\"#d8d8ff\",\"#3fffbf\",\"#b78c4c\",\n",
        "                    \"#339933\",\"#66b2b2\",\"#ba8c84\",\"#84bf00\",\"#b24c66\",\"#7f7f7f\",\"#3f3fa5\",\"#a5512b\"]\n",
        "\n",
        "def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,\n",
        "             color=\"pLDDT\", chains=None, vmin=50, vmax=90,\n",
        "             size=(800,480), hbondCutoff=4.0,\n",
        "             Ls=None,\n",
        "             animate=False):\n",
        "  \"\"\"Build a py3Dmol cartoon view of a PDB string.\n",
        "\n",
        "  Args:\n",
        "    pdb_str: PDB text handed to the viewer.\n",
        "    show_sidechains: also draw side-chain atoms as sticks/spheres.\n",
        "    show_mainchains: also draw the backbone atoms C, O, N and CA as sticks.\n",
        "    color: \"pLDDT\" (gradient over the 'b' property between `vmin` and `vmax`),\n",
        "      \"rainbow\" or \"chain\"; any other value leaves the default style.\n",
        "    chains: number of chains coloured in \"chain\" mode; defaults to len(Ls),\n",
        "      or 1 when `Ls` is None.\n",
        "    vmin, vmax: bounds of the \"pLDDT\" gradient.\n",
        "    size: (width, height) passed to py3Dmol.view.\n",
        "    hbondCutoff: forwarded to the model loader as the 'hbondCutoff' option.\n",
        "    Ls: optional list of chain lengths; only used to derive `chains`.\n",
        "    animate: load the string as frames and start the animation.\n",
        "\n",
        "  Returns:\n",
        "    The configured py3Dmol view; call .show() on it to render.\n",
        "  \"\"\"\n",
        "  if chains is None:\n",
        "    chains = 1 if Ls is None else len(Ls)\n",
        "  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n",
        "  model_opts = {'hbondCutoff':hbondCutoff}\n",
        "  if animate:\n",
        "    view.addModelsAsFrames(pdb_str,'pdb',model_opts)\n",
        "  else:\n",
        "    view.addModel(pdb_str,'pdb',model_opts)\n",
        "\n",
        "  if color == \"pLDDT\":\n",
        "    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})\n",
        "  elif color == \"rainbow\":\n",
        "    view.setStyle({'cartoon': {'color':'spectrum'}})\n",
        "  elif color == \"chain\":\n",
        "    # Cycle through the palette so chains past its end are still coloured, and\n",
        "    # use distinct loop names so the `color` argument is not overwritten.\n",
        "    for idx, chain_id in zip(range(chains), alphabet_list):\n",
        "      chain_color = pymol_color_list[idx % len(pymol_color_list)]\n",
        "      view.setStyle({'chain':chain_id},{'cartoon': {'color':chain_color}})\n",
        "\n",
        "  stick = {'stick':{'colorscheme':\"WhiteCarbon\",'radius':0.3}}\n",
        "  if show_sidechains:\n",
        "    BB = ['C','O','N']\n",
        "    # Everything except GLY/PRO: all atoms that are not C, O or N.\n",
        "    view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
        "                  stick)\n",
        "    # GLY: mark the CA atom with a sphere.\n",
        "    view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
        "                  {'sphere':{'colorscheme':\"WhiteCarbon\",'radius':0.3}})\n",
        "    # PRO: keep N in the selection, excluding only C and O.\n",
        "    view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
        "                  stick)\n",
        "  if show_mainchains:\n",
        "    BB = ['C','O','N','CA']\n",
        "    view.addStyle({'atom':BB},stick)\n",
        "  view.zoomTo()\n",
        "  if animate: view.animate()\n",
        "  return view\n",
        "\n",
        "color = \"confidence\" #@param [\"confidence\", \"rainbow\", \"chain\"]\n",
        "show_sidechains = False #@param {type:\"boolean\"}\n",
        "show_mainchains = False #@param {type:\"boolean\"}\n",
        "# The form calls the scheme \"confidence\"; show_pdb knows it as \"pLDDT\".\n",
        "color = \"pLDDT\" if color == \"confidence\" else color\n",
        "show_pdb(\n",
        "  pdb_str,\n",
        "  Ls=lengths,\n",
        "  color=color,\n",
        "  show_sidechains=show_sidechains,\n",
        "  show_mainchains=show_mainchains,\n",
        ").show()"
      ],
      "metadata": {
        "cellView": "form",
        "id": "dHl9aS-HHFNQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Download prediction\n",
        "# Sends the PDB file named after the job ID to the browser (Colab-only API).\n",
        "from google.colab import files\n",
        "files.download(f'{ID}.pdb')"
      ],
      "metadata": {
        "cellView": "form",
        "id": "12rxVAHSrmYQ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}